# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
# Modified by Yuze (dingyiwei@stu.xmu.edu.cn)
# ------------------------------------------------------------------------------
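# Typical launch (assuming ddp_opx reads the standard torchrun environment
# variables; the script name and config path below are placeholders):
#   torchrun --nproc_per_node=<num_gpus> train.py --cfg <experiment>.yaml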

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import pprint
import shutil

import numpy as np
import random

import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
from tensorboardX import SummaryWriter

import _init_paths  # adds the project's lib/ directory to sys.path so the imports below resolve
import models
from utils import ddp_opx
from core.trainer import Trainer
from core.loss import JointsMSELoss
from dataset.dataManager import DataManager
from config import cfg, update_config, get_args_parser
from utils.utils import create_logger, get_optimizer, save_checkpoint, merge_dicts


def main():
    # parse command-line arguments and merge them, together with the YAML file, into the global cfg
    args = get_args_parser()
    update_config(cfg, args)

    # initialize the (possibly single-process) distributed run; this is expected to
    # set args.distributed and args.gpu, which are used below
    ddp_opx.init_distributed_mode(args)
    device = torch.device(args.device)

    # cuDNN settings from the config
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    cudnn.enabled = cfg.CUDNN.ENABLED

    # offset the seed by rank so each process draws different random augmentations
    seed = args.seed + ddp_opx.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # >>>>>>>>>>>>>>>>>>>>>>>>> record log <<<<<<<<<<<<<<<<<<<<<<<<<
    writer_dict = None
    logger, final_output_dir, tb_log_dir = create_logger(cfg, args.cfg, ddp_opx.get_rank(), 'val' if args.eval else 'train')
    if ddp_opx.is_main_process():
        logger.info(pprint.pformat(args))
        logger.info(cfg)

        writer_dict = {
            'writer': SummaryWriter(log_dir=tb_log_dir),
            'train_global_steps': 0,
            'valid_global_steps': 0,
        }

        # copy model file
        this_dir = os.path.dirname(__file__)
        shutil.copy2(os.path.join(this_dir, 'lib/models', cfg.MODEL.NAME + '.py'), final_output_dir)

    trainer = Trainer(args, cfg)
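    # Trainer (core.trainer) provides the train_one_epoch / validate / evaluate_kpts
    # routines used throughout the rest of this script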

    # >>>>>>>>>>>>>>>>>>>>>>>>> Data <<<<<<<<<<<<<<<<<<<<<<<<<
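    # DataManager builds the datasets and dataloaders; the sampler.set_epoch call in
    # the training loop below suggests it uses a DistributedSampler when distributed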
    dataManager = DataManager(args, cfg)
    train_loader = dataManager.get_dataloader('train')
    valid_loader = dataManager.get_dataloader('val')

    # training-state defaults; the auto-resume block below overwrites them when a checkpoint exists
    begin_epoch = cfg.TRAIN.BEGIN_EPOCH
    best_perf = 0.0
    best_model = False
    last_epoch = -1
    
    # MSE between predicted and target heatmaps, optionally weighted per joint
    criterion = JointsMSELoss(
        use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT
    ).to(device)
    
    # >>>>>>>>>>>>>>>>>>>>>>>>> model <<<<<<<<<<<<<<<<<<<<<<<<<
    if cfg.DATASET.DATASET == 'vcoco':
        # V-COCO models additionally take the dataset's object-to-action mapping
        object_to_target = train_loader.dataset.hoi_data.object_to_action
        model = getattr(models, cfg.MODEL.NAME).get_pose_net(
            cfg, object_to_target, is_train=True
        )
    else:
        model = getattr(models, cfg.MODEL.NAME).get_pose_net(
            cfg, is_train=True
        )

    model.to(device)

    # log which parameters are still trainable (useful when fine-tuning with frozen layers)
    if ddp_opx.is_main_process():
        for name, p in model.named_parameters():
            if p.requires_grad:
                logger.info('optimizer finetune {}'.format(name))

    # hand only the trainable parameters to the optimizer
    optimizer = get_optimizer(cfg, filter(lambda p: p.requires_grad, model.parameters()))

    # >>>>>>>>>>>>>>>>>>>>>>>>> eval <<<<<<<<<<<<<<<<<<<<<<<<<
    if args.eval:
        # evaluation only: load the checkpoint on CPU first; strict=False tolerates
        # key mismatches between the checkpoint and the model
        ckpt_state_dict = torch.load(cfg.TEST.MODEL_FILE, map_location=torch.device('cpu'))
        model.load_state_dict(ckpt_state_dict, strict=False)
        one_record = trainer.validate(valid_loader, dataManager.valid_dataset, model, final_output_dir)
        all_records = ddp_opx.all_gather(one_record)
        if ddp_opx.is_main_process():
            logger.info('=> eval model of {}'.format(cfg.TEST.MODEL_FILE))
            all_records = merge_dicts(all_records)
            trainer.evaluate_kpts(all_records, dataManager.valid_dataset, final_output_dir)
        return

    # >>>>>>>>>>>>>>>>>>>>>>>>> Resume <<<<<<<<<<<<<<<<<<<<<<<<<
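    # auto-resume picks up the checkpoint.pth that save_checkpoint writes every epoch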
    checkpoint_file = os.path.join(final_output_dir, 'checkpoint.pth')
    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        checkpoint = torch.load(checkpoint_file, map_location=torch.device('cpu'))
        begin_epoch = checkpoint['epoch']
        best_perf = checkpoint['perf']
        last_epoch = checkpoint['epoch']

        if ddp_opx.is_main_process():
            logger.info("=> Auto resume loaded checkpoint '{}' (epoch {})".format(checkpoint_file, checkpoint['epoch']))
            writer_dict['train_global_steps'] = checkpoint['train_global_steps']
            writer_dict['valid_global_steps'] = checkpoint['valid_global_steps']

        # 'best_state_dict' was saved from model.module, so its keys match the
        # still-unwrapped model at this point
        model.load_state_dict(checkpoint['best_state_dict'], strict=False)
        optimizer.load_state_dict(checkpoint['optimizer'])

    # cosine annealing over the full run; last_epoch > -1 resumes the schedule mid-run
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, cfg.TRAIN.END_EPOCH, eta_min=cfg.TRAIN.LR_END, last_epoch=last_epoch)
    
    if args.distributed:
        # find_unused_parameters=True is needed when some parameters are frozen or
        # skipped in the forward pass; set it to False for speed if all are trained
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
    else:
        # wrap with DataParallel even on a single GPU so model.module.state_dict()
        # works the same way in the checkpointing code below
        model = torch.nn.DataParallel(model, device_ids=[args.gpu])
    
    # >>>>>>>>>>>>>>>>>>>>>>>>> begin to train <<<<<<<<<<<<<<<<<<<<<<<<<
    for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH):
        if args.distributed:
            # reshuffle the distributed sampler so each epoch sees a different data order
            train_loader.sampler.set_epoch(epoch)
        if ddp_opx.is_main_process():
            logger.info("=> current learning rate is {:.6f}".format(lr_scheduler.get_last_lr()[0]))

        trainer.train_one_epoch(train_loader, model, criterion, optimizer, epoch,
            final_output_dir, writer_dict)

        # validate on this rank's shard, then all-gather the per-rank records so the
        # main process can score the full validation split
        one_record = trainer.validate(valid_loader, dataManager.valid_dataset, model, final_output_dir)
        all_records = ddp_opx.all_gather(one_record)

        if ddp_opx.is_main_process():
            all_records = merge_dicts(all_records)
            perf_indicator = trainer.evaluate_kpts(all_records, dataManager.valid_dataset, final_output_dir, writer_dict)

        lr_scheduler.step()

        if ddp_opx.is_main_process():
            # only the main process computed perf_indicator, so best-model tracking
            # and checkpointing both live here
            best_model = perf_indicator >= best_perf
            if best_model:
                best_perf = perf_indicator

            logger.info('=> saving checkpoint to {}'.format(final_output_dir))
            save_checkpoint({
                'epoch': epoch + 1,
                'model': cfg.MODEL.NAME,
                'state_dict': model.state_dict(),
                # model.module strips the DDP/DataParallel wrapper so these keys load into a bare model
                'best_state_dict': model.module.state_dict(),
                'perf': perf_indicator,
                'optimizer': optimizer.state_dict(),
                'train_global_steps': writer_dict['train_global_steps'],
                'valid_global_steps': writer_dict['valid_global_steps'],
            }, best_model, final_output_dir)

    if ddp_opx.is_main_process():
        final_model_state_file = os.path.join(
            final_output_dir, 'final_state.pth'
        )
        logger.info('=> saving final model state to {}'.format(
            final_model_state_file)
        )
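        # save the unwrapped weights (model.module) so they load into a bare model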
        torch.save(model.module.state_dict(), final_model_state_file)
        writer_dict['writer'].close()

    ddp_opx.cleanup()
    print("#####\nTraining Done!\n#####")


if __name__ == '__main__':
    main()